"""
Compiles YAML prompts into Python code.
"""

import argparse
import inspect
import keyword
from pathlib import Path
from typing import Literal

import yaml
from jinja2 import Template
from pydantic import BaseModel


# Base message class copied into the compiled module.
# NOTE: get_models_file_contents() copies this class's source text verbatim
# with inspect.getsource(), so anything written inside the class body ends up
# in the generated _models.py. The generated file only imports Literal and
# BaseModel, so the body must not depend on any other name.
class PromptMessage(BaseModel):
    role: Literal["user"]
    content: str


# Base classification evaluator config class copied into the compiled module.
# It is also the schema each YAML prompt file is validated against before it
# is compiled.
# NOTE: get_models_file_contents() copies this class's source text verbatim
# with inspect.getsource(), so anything written inside the class body ends up
# in the generated _models.py. The generated file only imports Literal and
# BaseModel, and defines PromptMessage before this class.
class ClassificationEvaluatorConfig(BaseModel):
    name: str
    description: str
    optimization_direction: Literal["minimize", "maximize"]
    messages: list[PromptMessage]
    choices: dict[str, float]


# Jinja2 template for the generated _models.py file. The two placeholders are
# filled by get_models_file_contents() with the source text of PromptMessage
# and ClassificationEvaluatorConfig; the imports below are the only names that
# copied source can rely on.
MODELS_TEMPLATE = """\
# This file is generated. Do not edit by hand.

from typing import Literal

from pydantic import BaseModel


{{ prompt_message_source }}

{{ classification_evaluator_config_source }}
"""

# Jinja2 template for one generated prompt module. get_prompt_file_contents()
# fills in the module-level variable name and the repr() of the validated
# config; the rendered module imports both model classes from ._models.
# NOTE(review): E501 (line too long) is presumably silenced because the repr
# is emitted as one long expression — confirm.
CLASSIFICATION_EVALUATOR_CONFIG_TEMPLATE = """\
# This file is generated. Do not edit by hand.
# ruff: noqa: E501

from ._models import ClassificationEvaluatorConfig, PromptMessage

{{ classification_evaluator_config_name }} = {{ classification_evaluator_config_definition }}
"""

# Jinja2 template for the generated package __init__.py, rendered by
# get_init_file_contents(). It re-exports the two model classes plus one
# object per prompt name, each imported from the module named
# `_<lowercased prompt name>`, and lists all of them in __all__.
INIT_TEMPLATE = """\
# This file is generated. Do not edit by hand.

from ._models import ClassificationEvaluatorConfig, PromptMessage
{% for name in prompt_names -%}
from ._{{ name.lower() }} import {{ name }}
{% endfor %}

__all__ = [
    "ClassificationEvaluatorConfig",
    "PromptMessage",
    {{ prompt_names|map('tojson')|join(', ') }}
]
"""


def get_models_file_contents() -> str:
    """
    Renders the source text of the generated _models.py file.

    The source code of PromptMessage and ClassificationEvaluatorConfig is read
    with inspect.getsource, stripped of surrounding whitespace, and
    substituted into MODELS_TEMPLATE.
    """
    model_sources = {
        "prompt_message_source": inspect.getsource(PromptMessage).strip(),
        "classification_evaluator_config_source": inspect.getsource(
            ClassificationEvaluatorConfig
        ).strip(),
    }
    return Template(MODELS_TEMPLATE).render(**model_sources)


def get_prompt_file_contents(config: ClassificationEvaluatorConfig, name: str) -> str:
    """
    Renders the Python module source for a single ClassificationEvaluatorConfig.

    The rendered module assigns repr(config) to a module-level variable
    called `name`, using CLASSIFICATION_EVALUATOR_CONFIG_TEMPLATE.
    """
    render_context = {
        "classification_evaluator_config_name": name,
        "classification_evaluator_config_definition": repr(config),
    }
    return Template(CLASSIFICATION_EVALUATOR_CONFIG_TEMPLATE).render(**render_context)


def get_init_file_contents(prompt_names: list[str]) -> str:
    """
    Renders the __init__.py source that re-exports every compiled prompt.

    `prompt_names` is passed to INIT_TEMPLATE, which emits one import per name
    and lists each name in __all__.
    """
    init_template = Template(INIT_TEMPLATE)
    return init_template.render(prompt_names=prompt_names)


if __name__ == "__main__":
    # Command-line entry point: validates every YAML prompt under the prompts
    # directory and writes a Python package (one module per prompt, plus
    # _models.py and __init__.py) to the path given on the command line.
    parser = argparse.ArgumentParser(description="Compile YAML prompts to Python code")
    parser.add_argument(
        "compiled_module_path",
        type=Path,
        help="Path to the compiled module",
    )

    args = parser.parse_args()

    output_dir = args.compiled_module_path
    # Relative path, so it is resolved against the current working directory.
    prompts_dir = Path("prompts/classification_evaluator_configs")

    # Sorting keeps the generated __init__.py deterministic across runs.
    yaml_files = sorted(prompts_dir.glob("*.yaml"))

    # Each file stem becomes a Python variable name and appears in an import
    # statement in __init__.py. Reject stems that would produce invalid Python
    # before anything is written, rather than emitting a broken package.
    invalid_files = [
        yaml_file.name
        for yaml_file in yaml_files
        if not yaml_file.stem.isidentifier() or keyword.iskeyword(yaml_file.stem)
    ]
    if invalid_files:
        raise SystemExit(
            "Cannot compile prompts: file names must be valid Python identifiers "
            f"(offending files in {prompts_dir}: {', '.join(invalid_files)})"
        )

    # Ensure output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate _models.py containing Pydantic model definitions
    models_content = get_models_file_contents()
    models_path = output_dir / "_models.py"
    models_path.write_text(models_content, encoding="utf-8")

    # Compile all YAML prompts to Python
    prompt_names = []

    for yaml_file in yaml_files:
        # Read and validate YAML
        with open(yaml_file, "r", encoding="utf-8") as f:
            raw_config = yaml.safe_load(f)
        config = ClassificationEvaluatorConfig.model_validate(raw_config)

        # Generate Python code using YAML filename as the module/variable name
        name = yaml_file.stem
        content = get_prompt_file_contents(config, name)
        prompt_names.append(name)

        # Write to file; the module file name is the lowercased prompt name,
        # matching the import emitted by INIT_TEMPLATE.
        output_path = output_dir / f"_{name.lower()}.py"
        output_path.write_text(content, encoding="utf-8")

    # Generate the __init__.py file
    init_content = get_init_file_contents(prompt_names)
    init_path = output_dir / "__init__.py"
    init_path.write_text(init_content, encoding="utf-8")
